import os

os.environ["CUDA_VISIBLE_DEVICES"] = '2,3'
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase

from langchain.llms import VLLM
from langchain import OpenAI, SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.chains import create_sql_query_chain

model_dir = '/datasets/fengjiahao/nlp/TongyiFinance/Tongyi-Finance-14B'

llm = VLLM(
    model=model_dir,
    trust_remote_code=True,  # mandatory for hf models
    temperature=0.2,
    top_p=0.7,
    top_k=10,
    tensor_parallel_size=2, verbose=True
)

db = SQLDatabase.from_uri(
        "sqlite:////datasets/fengjiahao/nlp/bs_challenge_financial_14b_dataset/dataset/博金杯比赛数据.db",
        include_tables=[ 'A股股票行情'],
        view_support=True,
        sample_rows_in_table_info=2)
print(db.table_info)

question = db.table_info+"\n问题是：请帮我计算，在20210105，中信行业分类划分的一级行业为综合金融行业中，涨跌幅最大股票的股票代码是？涨跌幅是多少？百分数保留两位小数。股票涨跌幅定义为：（收盘价 - 前一日收盘价 / 前一日收盘价）* 100%。请写出查询数据库的SQL:"
response = llm(question)
print(response)
print('---------------------------------------------')

# toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True, return_intermediate_steps=True)


# response = db_chain.invoke({"question": "请帮我计算，在20210105，中信行业分类划分的一级行业为综合金融行业中，涨跌幅最大股票的股票代码是？涨跌幅是多少？百分数保留两位小数。股票涨跌幅定义为：（收盘价 - 前一日收盘价 / 前一日收盘价）* 100%。"})
# print(response)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "请帮我计算，在20210105，中信行业分类划分的一级行业为综合金融行业中，涨跌幅最大股票的股票代码是？涨跌幅是多少？百分数保留两位小数。股票涨跌幅定义为：（收盘价 - 前一日收盘价 / 前一日收盘价）* 100%。"})
print(response)
print()


print('---------------------------------------------')
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True, return_intermediate_steps=True)
response = db_chain.invoke({"query": "请帮我计算，在20210105，中信行业分类划分的一级行业为综合金融行业中，涨跌幅最大股票的股票代码是？涨跌幅是多少？百分数保留两位小数。股票涨跌幅定义为：（收盘价 - 前一日收盘价 / 前一日收盘价）* 100%。"})
print(response)